% Estimate and conduct inference on individual-specific treatment effects in the JTPA data

%% Organize data
% Run the data-preparation script. The code below expects it to leave the
% outcome y, the treatment indicator d and the covariate matrix X
% (including interactions) in the workspace.
createJTPAdata ;

% Sample size (n) and number of covariate columns (px)
[n,px] = size(X);

% Full design matrix with all treatment interactions. Column layout:
%   1              intercept
%   2 : px+1       covariates X
%   px+2           treatment indicator d
%   px+3 : 2*px+2  d interacted with each column of X
Z = [ones(n,1) , X , d , repmat(d,1,px).*X];

% Covariates split by arm: d == 1 is treatment, d == 0 is control
Xd1 = X(d == 1,:);
Xd0 = X(d == 0,:);

% Outcomes split the same way
yd1 = y(d == 1,:);
yd0 = y(d == 0,:);

%% OLS with everything
% Least-squares coefficients on the full interacted design
bolsJTPA = mldivide(Z,y);

% hetero_se (defined elsewhere) receives the design, the OLS residuals and
% inv(Z'Z). Its second output is used later as the coefficient covariance
% matrix; the first is presumably the matching standard errors -- confirm
% against hetero_se.
[solsJTPA,VolsJTPA] = hetero_se(Z, y - Z*bolsJTPA, inv(Z'*Z));

%% Lasso with plug-in tuning
% Control observations
% Starting values: least squares of the demeaned outcome on the demeaned
% covariates whose absolute marginal correlation with the outcome is among
% the 5 largest (ties at the cutoff are all kept)
czy0 = corr(yd0,Xd0);
czys = sort(-abs(czy0));
suppy0 = find(abs(czy0) >= -czys(5));
tmpy0 = bsxfun(@minus,Xd0(:,suppy0),mean(Xd0(:,suppy0))) \ (yd0-mean(yd0));

% Full-length starting vector: zero everywhere except the initial support
by0 = zeros(px,1);
by0(suppy0) = tmpy0;

% feasiblePostLasso (defined elsewhere) is run on the demeaned data with
% by0 passed as 'beta0' and 'MaxIter' set to 1
blas0 = feasiblePostLasso(yd0-mean(yd0),bsxfun(@minus,Xd0,mean(Xd0)),...
    'MaxIter',1,'beta0',by0);

% Prepend the intercept implied by the demeaning
blas0 = [(mean(yd0)-mean(Xd0)*blas0);blas0];

% Logical mask of nonzero entries (intercept first, then covariates)
use0 = (blas0 ~= 0);

% Treatment observations
% Starting values: least squares of the demeaned outcome on the demeaned
% covariates whose absolute marginal correlation with the outcome is among
% the 5 largest (ties at the cutoff are all kept)
czy1 = corr(yd1,Xd1);
czys = sort(-abs(czy1));
suppy1 = find(abs(czy1) >= -czys(5));

% Centering is done with bsxfun so the row count always comes from Xd1
% itself. The previous ones(sum(d),1) only matched the d == 1 subset when d
% was coded exactly 0/1, and was inconsistent with the control arm.
tmpy1 = bsxfun(@minus,Xd1(:,suppy1),mean(Xd1(:,suppy1))) \ (yd1-mean(yd1));

% Full-length starting vector: zero everywhere except the initial support
by1 = zeros(size(Xd1,2),1);
by1(suppy1) = tmpy1;

% feasiblePostLasso (defined elsewhere) is run on the demeaned data with
% by1 passed as 'beta0' and 'MaxIter' set to 1
blas1 = feasiblePostLasso(yd1-mean(yd1),bsxfun(@minus,Xd1,mean(Xd1)),...
    'MaxIter',1,'beta0',by1);

% Prepend the intercept implied by the demeaning
blas1 = [(mean(yd1)-mean(Xd1)*blas1);blas1];

% Logical mask of nonzero entries (intercept first, then covariates)
use1 = blas1 ~= 0;

% Post-lasso
% Keep every term selected in either arm. Positions index
% [intercept, covariates], i.e. columns 1:px+1 of Z; shifting by px+1 adds
% the matching treatment column (d itself for the intercept, d.*X(:,j) for
% covariate j).
usepl = union(find(use0),find(use1));
usepl = [usepl ; usepl + (px+1)];

% Refit by least squares on the selected columns only
Zpl = Z(:,usepl);
bpl = mldivide(Zpl,y);
[sepl,Vpl] = hetero_se(Zpl, y - Zpl*bpl, inv(Zpl'*Zpl));

% Embed the results in containers shaped like the full OLS output;
% entries for unselected columns stay at zero
bplJTPA = zeros(size(bolsJTPA));
splJTPA = zeros(size(bolsJTPA));
VplJTPA = zeros(size(VolsJTPA));
bplJTPA(usepl) = bpl;
splJTPA(usepl) = sepl;
VplJTPA(usepl,usepl) = Vpl;

%% Conditional average treatment effects

% Row i of Xte is zero on the intercept/covariate columns of Z, one on the
% treatment column, and X(i,:) on the interaction columns, so Xte*b is the
% treatment coefficient plus the interaction terms evaluated at X(i,:).
Xte = [zeros(n,(px+1)) , ones(n,1) , X];

% OLS estimates
teolsJTPA = Xte*bolsJTPA;
% Pointwise standard errors sqrt(x_i*V*x_i'), computed row by row. This
% equals sqrt(diag(Xte*V*Xte')) up to rounding but avoids forming the
% n-by-n product just to read its diagonal.
steolsJTPA = sqrt(sum((Xte*VolsJTPA).*Xte,2));

% Post-lasso estimates (same quadratic form with the post-lasso covariance)
teplJTPA = Xte*bplJTPA;
steplJTPA = sqrt(sum((Xte*VplJTPA).*Xte,2));

% TU estimates
% One row per observation and ceil(log(n)) columns, filled by fsel_linear
% (defined elsewhere), which returns a row of lower and a row of upper
% bounds for the evaluation point Xte(ii,:)'.
TUteJTPAlb = zeros(n,ceil(log(n)));
TUteJTPAub = zeros(n,ceil(log(n)));
parfor ii = 1:n
    [TUteJTPAlb(ii,:),TUteJTPAub(ii,:)] = fsel_linear(Z,y,Xte(ii,:)',usepl,ceil(log(n)));
end
    
%% save output

% Write every workspace variable to JTPA_CATE.mat
save('JTPA_CATE');

%% Make plot of estimated effects
% Plots sorted by treatment effect estimate
% Column 1 of each array is the point estimate used as the sort key.
olsres = [teolsJTPA steolsJTPA];
olsres = sortrows(olsres,1);

plres = [teplJTPA steplJTPA];
plres = sortrows(plres,1);

% Forward-selection bounds are ordered by the post-lasso estimate so they
% line up with plres in the plots below
fsreslb = [teplJTPA TUteJTPAlb];
fsreslb = sortrows(fsreslb,1);
fsresub = [teplJTPA TUteJTPAub];
fsresub = sortrows(fsresub,1);

% OLS: estimate (blue) with +/- 1.96 standard-error band (red)
figure; plot(1:n,olsres(:,1),'b', ...
    1:n,olsres(:,1)-1.96*olsres(:,2),'r', ...
    1:n,olsres(:,1)+1.96*olsres(:,2),'r');
title('OLS Pointwise 95% CI for CATE(x)');
savefig('JTPACATE_OLS');

% Post-lasso: estimate (blue) with +/- 1.96 standard-error band (red)
figure; plot(1:n,plres(:,1),'b', ...
    1:n,plres(:,1)-1.96*plres(:,2),'r', ...
    1:n,plres(:,1)+1.96*plres(:,2),'r');
title('Post-Lasso Pointwise 95% CI for CATE(x)');
savefig('JTPACATE_PL');

% One figure per forward-selection column. The loop bound comes from the
% bounds matrix rather than a hard-coded 9, so it stays in step with the
% ceil(log(n)) columns computed earlier.
for ii = 1:size(TUteJTPAlb,2)
    figure; plot(1:n,plres(:,1),'b',1:n,fsreslb(:,1+ii),'r',1:n,fsresub(:,1+ii),'r');
    % A literal percent sign must be written as %% inside a sprintf format
    title(sprintf('FS(%d) Pointwise 95%% CI for CATE(x)',ii));
    filename = sprintf('JTPACATE_FS%d',ii);
    savefig(filename);
end

%% Treatment effects for a few individuals
% Just look at first individual within each level of treatment effect
unique_te = unique(teplJTPA);
nte = numel(unique_te);

% Row jj describes the first observation whose post-lasso estimate equals
% unique_te(jj). Columns of indlbs/indubs: 1 = OLS bound, 2 = post-lasso
% bound, 3:end = the TU bounds for that observation.
indpes = zeros(nte,2);
indlbs = zeros(nte,2+size(TUteJTPAlb,2));
indubs = zeros(nte,2+size(TUteJTPAub,2));
for jj = 1:nte
    indjj = find(teplJTPA == unique_te(jj));

    % Point estimates: [OLS , post-lasso]
    indpes(jj,:) = [teolsJTPA(indjj(1)) , teplJTPA(indjj(1))];

    % Lower bounds: estimate - 1.96*se for OLS and post-lasso, then TU
    indlbs(jj,1:2) = indpes(jj,:) - 1.96*[steolsJTPA(indjj(1)) , steplJTPA(indjj(1))];
    indlbs(jj,3:end) = TUteJTPAlb(indjj(1),:);

    % Upper bounds: estimate + 1.96*se for OLS and post-lasso, then TU
    indubs(jj,1:2) = indpes(jj,:) + 1.96*[steolsJTPA(indjj(1)) , steplJTPA(indjj(1))];
    indubs(jj,3:end) = TUteJTPAub(indjj(1),:);
end

% Make plots for these individuals
% x-axis: 0 is the post-lasso interval (column 2 of indlbs/indubs) and
% 1..K are the K remaining columns. The grid is derived from the matrix
% width instead of a hard-coded 0:9, so it cannot drift out of step with
% the number of columns.
xfs = 0:(size(indlbs,2)-2);
flat = ones(size(xfs));
for jj = 1:nte
    % OLS interval (solid red) and OLS point estimate (dashed red),
    % drawn as horizontal reference lines
    figure; plot(xfs,indlbs(jj,1)*flat,'r',xfs,indubs(jj,1)*flat,'r');
    hold on;
    plot(xfs,indpes(jj,1)*flat,'r--');
    % Post-lasso / forward-selection bounds (solid blue) and post-lasso
    % point estimate (dashed blue)
    plot(xfs,indlbs(jj,2:end),'b',xfs,indubs(jj,2:end),'b');
    plot(xfs,indpes(jj,2)*flat,'b--');
    % Zero line for reference
    plot(xfs,zeros(size(xfs)),'k');
    % A literal percent sign must be written as %% inside a sprintf format;
    % unescaped, it corrupted the title and dropped the individual number
    title(sprintf('95%% Interval Estimate for CATE for Individual %d',jj));
    xlabel('Number of Additional Variables');
    ylabel('CATE');
    filename = sprintf('JTPAind%d',jj);
    saveas(gcf, filename, 'jpeg');
end

%% Wald test that all treatment-by-covariate interaction coefficients are zero
% Columns px+3 : 2*px+2 of Z hold the d.*X interactions
testset = ((px+3):(2*px+2))';
% Hypothesized coefficient values (all zero)
r = zeros(size(bolsJTPA));

% OLS
WOLS = (bolsJTPA(testset) - r(testset))'*(VolsJTPA(testset,testset)\(bolsJTPA(testset) - r(testset)));
% The upper tail is requested directly: 1-chi2cdf(...) rounds to exactly 0
% once the cdf is within eps of 1, losing very small p-values
pvalOLS = chi2cdf(WOLS,numel(testset),'upper');

% Post-lasso
% Only interactions that were actually selected can be tested; the rest
% have zero rows/columns in VplJTPA
testpl = intersect(testset,usepl);
WPL = (bplJTPA(testpl) - r(testpl))'*(VplJTPA(testpl,testpl)\(bplJTPA(testpl) - r(testpl)));
pvalPL = chi2cdf(WPL,numel(testpl),'upper');

% Forward selection
% fsel_wald is defined elsewhere; it is given one more step than the
% ceil(log(n)) passed to fsel_linear earlier in this script
[FSpval,FSW,FSdf] = fsel_wald(Z,y,testset,r,usepl,ceil(log(n))+1);

%% Save results again
% Overwrite JTPA_CATE.mat with the workspace as it stands now
save('JTPA_CATE');

